{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Nifti Read Example\n",
    "\n",
    "The purpose of this notebook is to illustrate reading Nifti files and iterating over patches of the volumes loaded from them.\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/master/modules/nifti_read_example.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!python -c \"import monai\" || pip install -q \"monai-weekly[nibabel]\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MONAI version: 0.6.0rc1+23.gc6793fd0\n",
      "Numpy version: 1.20.3\n",
      "Pytorch version: 1.9.0a0+c3d40fd\n",
      "MONAI flags: HAS_EXT = True, USE_COMPILED = False\n",
      "MONAI rev id: c6793fd0f316a448778d0047664aaf8c1895fe1c\n",
      "\n",
      "Optional dependencies:\n",
      "Pytorch Ignite version: 0.4.5\n",
      "Nibabel version: 3.2.1\n",
      "scikit-image version: 0.15.0\n",
      "Pillow version: 7.0.0\n",
      "Tensorboard version: 2.5.0\n",
      "gdown version: 3.13.0\n",
      "TorchVision version: 0.10.0a0\n",
      "ITK version: 5.1.2\n",
      "tqdm version: 4.53.0\n",
      "lmdb version: 1.2.1\n",
      "psutil version: 5.8.0\n",
      "pandas version: 1.1.4\n",
      "einops version: 0.3.0\n",
      "\n",
      "For details about installing the optional dependencies, please visit:\n",
      "    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Copyright 2020 MONAI Consortium\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "\n",
    "import glob\n",
    "import os\n",
    "import shutil\n",
    "import tempfile\n",
    "\n",
    "import nibabel as nib\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from monai.config import print_config\n",
    "from monai.data import (\n",
    "    ArrayDataset, GridPatchDataset, create_test_image_3d, PatchIter)\n",
    "from monai.transforms import (\n",
    "    AddChannel,\n",
    "    Compose,\n",
    "    LoadImage,\n",
    "    RandSpatialCrop,\n",
    "    ScaleIntensity,\n",
    "    EnsureType,\n",
    ")\n",
    "from monai.utils import first\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup data directory\n",
    "\n",
    "You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  \n",
    "This allows you to save results and reuse downloads.  \n",
    "If not specified, a temporary directory will be used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/tmp/tmp3y4h4i5o\n"
     ]
    }
   ],
   "source": [
    "# Use the directory named by MONAI_DATA_DIRECTORY when it is set; otherwise\n",
    "# fall back to a fresh temporary directory.\n",
    "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n",
    "if directory:\n",
    "    # A user-supplied directory may not exist yet; create it so that\n",
    "    # files can be written into it.\n",
    "    os.makedirs(directory, exist_ok=True)\n",
    "root_dir = tempfile.mkdtemp() if directory is None else directory\n",
    "print(root_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a number of test Nifti files:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate five image/segmentation volume pairs with create_test_image_3d\n",
    "# and save each volume into root_dir with an identity affine (np.eye(4)).\n",
    "for i in range(5):\n",
    "    im, seg = create_test_image_3d(128, 128, 128)\n",
    "    im_path = os.path.join(root_dir, f\"im{i}.nii.gz\")\n",
    "    seg_path = os.path.join(root_dir, f\"seg{i}.nii.gz\")\n",
    "\n",
    "    nib.save(nib.Nifti1Image(im, np.eye(4)), im_path)\n",
    "    nib.save(nib.Nifti1Image(seg, np.eye(4)), seg_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a data loader which yields uniform random patches from loaded Nifti files:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([5, 1, 64, 64, 64]) torch.Size([5, 1, 64, 64, 64])\n"
     ]
    }
   ],
   "source": [
    "# Gather the image and segmentation files found in root_dir, sorted by name.\n",
    "image_pattern = os.path.join(root_dir, \"im*.nii.gz\")\n",
    "seg_pattern = os.path.join(root_dir, \"seg*.nii.gz\")\n",
    "images = sorted(glob.glob(image_pattern))\n",
    "segs = sorted(glob.glob(seg_pattern))\n",
    "\n",
    "# Both pipelines crop to the same size, each with its own crop transform.\n",
    "crop_size = (64, 64, 64)\n",
    "\n",
    "imtrans = Compose([\n",
    "    LoadImage(image_only=True),\n",
    "    ScaleIntensity(),\n",
    "    AddChannel(),\n",
    "    RandSpatialCrop(crop_size, random_size=False),\n",
    "    EnsureType(),\n",
    "])\n",
    "\n",
    "# The segmentation pipeline is the image pipeline without ScaleIntensity.\n",
    "segtrans = Compose([\n",
    "    LoadImage(image_only=True),\n",
    "    AddChannel(),\n",
    "    RandSpatialCrop(crop_size, random_size=False),\n",
    "    EnsureType(),\n",
    "])\n",
    "\n",
    "ds = ArrayDataset(images, imtrans, segs, segtrans)\n",
    "\n",
    "loader = torch.utils.data.DataLoader(\n",
    "    ds,\n",
    "    batch_size=10,\n",
    "    num_workers=2,\n",
    "    pin_memory=torch.cuda.is_available(),\n",
    ")\n",
    "\n",
    "# Fetch only the first batch and show the image and segmentation shapes.\n",
    "im, seg = first(loader)\n",
    "print(im.shape, seg.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Alternatively, create a data loader which yields patches in regular grid order from loaded images:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "image shapes: torch.Size([3, 1, 64, 64, 64]) torch.Size([3, 1, 64, 64, 64])\n",
      "coordinates shapes: torch.Size([3, 4, 2]) torch.Size([3, 4, 2])\n"
     ]
    }
   ],
   "source": [
    "# No crop transform in these pipelines: patches come from PatchIter instead.\n",
    "imtrans = Compose(\n",
    "    [\n",
    "        LoadImage(image_only=True),\n",
    "        ScaleIntensity(),\n",
    "        AddChannel(),\n",
    "        EnsureType(),\n",
    "    ]\n",
    ")\n",
    "\n",
    "segtrans = Compose(\n",
    "    [\n",
    "        LoadImage(image_only=True),\n",
    "        AddChannel(),\n",
    "        EnsureType(),\n",
    "    ]\n",
    ")\n",
    "\n",
    "ds = ArrayDataset(images, imtrans, segs, segtrans)\n",
    "patch_iter = PatchIter(patch_size=(64, 64, 64), start_pos=(0, 0, 0))\n",
    "\n",
    "\n",
    "def img_seg_iter(x):\n",
    "    \"\"\"Zip the patches of `x[0]` with the patches of `x[1]`.\n",
    "\n",
    "    `x` is presumably the (image, segmentation) pair produced by the\n",
    "    ArrayDataset above (TODO confirm). Both entries are passed to\n",
    "    `patch_iter`, and the zip of the two results is returned wrapped in a\n",
    "    one-element tuple.\n",
    "    \"\"\"\n",
    "    image_patches = patch_iter(x[0])\n",
    "    seg_patches = patch_iter(x[1])\n",
    "    return (zip(image_patches, seg_patches),)\n",
    "\n",
    "\n",
    "ds = GridPatchDataset(ds, img_seg_iter, with_coordinates=False)\n",
    "\n",
    "loader = torch.utils.data.DataLoader(\n",
    "    ds,\n",
    "    batch_size=10,\n",
    "    num_workers=2,\n",
    "    pin_memory=torch.cuda.is_available(),\n",
    ")\n",
    "\n",
    "# Each of `im` and `seg` is indexed with [0] for the patch data and [1]\n",
    "# for the patch coordinates.\n",
    "im, seg = first(loader)\n",
    "print(\"image shapes:\", im[0].shape, seg[0].shape)\n",
    "print(\"coordinates shapes:\", im[1].shape, seg[1].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cleanup data directory\n",
    "\n",
    "Remove the directory if a temporary one was used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Delete root_dir only when MONAI_DATA_DIRECTORY was not set, i.e. when\n",
    "# root_dir is the temporary directory created by tempfile.mkdtemp().\n",
    "used_temp_dir = directory is None\n",
    "if used_temp_dir:\n",
    "    shutil.rmtree(root_dir)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
